import numpy as np
import cv2
import scipy
import sympy
from scipy.optimize import check_grad, approx_fprime
from sklearn.decomposition import PCA

# NOTE(review): `cs` is not read by any function visible in this module; presumably a cube
# size in metres — confirm against the callers before relying on it.
cs = 0.012

# Alternative set of intrinsics kept for reference (disabled):
# intrinsics = {'fx': 923.696, 'fy': 923.115, 'height': 720, 'width': 1280, 'ppx': 630.634, 'ppy': 367.255}
# Intrinsics of the depth stream; not read by the functions visible in this module.
depth_intrinsics = {'width': 1280, 'height': 720, 'ppx': 641.108, 'ppy': 357.824, 'fx': 637.89, 'fy': 637.89}
# Active pinhole intrinsics: image size ('width', 'height'), principal point ('ppx', 'ppy') and
# focal lengths ('fx', 'fy'). Read by get_pixel_coords and create_xyz; replaced by set_intrinsics.
intrinsics = {'width': 640, 'height': 480, 'ppx': 313.756, 'ppy': 244.836, 'fx': 615.797, 'fy': 615.41}


def set_intrinsics(intr):
    """Replace the camera intrinsics used by this module.

    The given dict is stored (by reference, without copying) in the module-level
    ``intrinsics`` variable. Because this is plain module state, the change is only
    seen by code running in the same interpreter, i.e. the threaded version and not
    the Process version. It is admittedly not a very nice way to handle this.

    :param intr: dict with the same keys as the default intrinsics
    :return: None
    """
    global intrinsics
    intrinsics = intr

def generate_checkerboard_points(x_size, y_size, dim):
    """ Build the ideal checkerboard corner grid lying in the plane z = 0

    Points are emitted row by row: x varies fastest, y slowest, and the first
    point is the origin.

    :param x_size: number of corners in the x-direction
    :param y_size: number of corners in the y-direction
    :param dim: dimension (edge length) of one square
    :return: numpy array with one [x, y, 0.0] row per corner
    """
    grid = [
        [0.0 + dim * col, 0.0 + dim * row, 0.0]
        for row in range(y_size)
        for col in range(x_size)
    ]
    return np.array(grid)


def get_Rt(orig, dst):
    """ Returns the rotation matrix and translation that rigidly align a pair of 3D point sets

    This function uses the Kabsch algorithm (https://en.wikipedia.org/wiki/Kabsch_algorithm)

    The output can be used as: dst = (R @ orig.T).T + t

    As a diagnostic, the mean Euclidean alignment error over all points is printed.

    :param orig: n x 3 np array of points in the original coordinate system
    :param dst: n x 3 np array of the corresponding points in the desired coordinate system
    :return: R - rotation matrix 3 x 3, t - 3d translation vector
    """
    orig_centroid = np.mean(orig, axis=0)
    dst_centroid = np.mean(dst, axis=0)
    new_orig = orig - orig_centroid
    new_dst = dst - dst_centroid

    # Cross-covariance of the centred point sets.
    H = new_orig.T @ new_dst

    # np.linalg.svd returns the right singular vectors already transposed
    # (H = U @ diag(S) @ Vh), so V has to be recovered with a transpose.
    U, _, Vh = np.linalg.svd(H)
    V = Vh.T

    # Flip the last axis if needed so that R is a proper rotation (det = +1)
    # and not a reflection.
    d = np.linalg.det(V @ U.T)
    I_d = np.eye(3, 3)
    I_d[2, 2] = -1.0 if d < 0 else 1.0

    R = V @ I_d @ U.T

    t = dst_centroid - (R @ orig_centroid)

    # Mean residual distance after alignment.
    print(np.mean(np.sqrt(np.sum(((R @ orig.T).T + t - dst)**2, axis=-1))))

    return R, t


def get_interpolated_xyz(xyz, corners):
    """ Interpolate the xyz positions of corners using the bilinear interpolation

    The algorithm used: https://en.wikipedia.org/wiki/Bilinear_interpolation#Nonlinear

    No bounds checking is performed: every corner, rounded both down and up, has to be
    a valid index into the first two axes of xyz.

    :param xyz: 3d pointcloud ordered as an image (indexed as xyz[row, column])
    :param corners: output from cv2.findChessboardCorners, indexed here as corners[:, 0, :];
        the two coordinates of each corner are reversed before being used as xyz indices
    :return: n x 3 array of points of interpolated xyz values
    """
    # Reverse the coordinate pair so that column 0 indexes axis 0 of xyz and column 1 axis 1.
    idxs = corners[:, 0, ::-1]
    # The np.int alias was removed from NumPy; the builtin int is the equivalent dtype.
    idxs_lower = np.floor(idxs).astype(int)
    idxs_upper = np.ceil(idxs).astype(int)
    # Fractional position of each corner inside its surrounding pixel cell.
    idxs_x = idxs - idxs_lower

    lo0, lo1 = idxs_lower[:, 0], idxs_lower[:, 1]
    up0, up1 = idxs_upper[:, 0], idxs_upper[:, 1]

    # Coefficients of f(x, y) = a00 + a10 * x + a01 * y + a11 * x * y on the unit cell.
    a00 = xyz[lo0, lo1]
    a10 = xyz[up0, lo1] - a00
    a01 = xyz[lo0, up1] - a00
    a11 = xyz[up0, up1] + a00 - (xyz[lo0, up1] + xyz[up0, lo1])

    frac0 = idxs_x[:, 0, np.newaxis]
    frac1 = idxs_x[:, 1, np.newaxis]
    orig_points = a00 + a10 * frac0 + a01 * frac1 + a11 * frac0 * frac1
    return orig_points


def get_checkerboard_extrinsics(rgb, xyz, orig_shift=None, debug=False, pattern_size=(9, 6), square_size=0.022):
    """ Return rotation matrix and translation vector for coordinate system described by the checkerboard pattern

    The checkerboard corners are detected in the image, their 3D positions are interpolated
    from the pointcloud, and the resulting points are rigidly aligned (see get_Rt) with an
    ideal corner grid lying in the plane z = 0 of the new coordinate system.

    :param rgb: RGB image
    :param xyz: XYZ pointcloud ordered as the RGB image
    :param orig_shift: 3d vector of the position of the first corner in the pattern in the new coordinate system
    :param debug: bool - display debug images
    :param pattern_size: (corners in the x-direction, corners in the y-direction) of the pattern;
        defaults to the previously hard-coded (9, 6)
    :param square_size: dimension of one checkerboard square, in the units of the new coordinate
        system (presumably metres, matching xyz); defaults to the previously hard-coded 0.022
    :return: R - rotation matrix 3 x 3, t - 3d translation vector; (None, None) when the pattern
        is not detected or fewer than three corners have a positive depth
    """
    pattern_size = tuple(pattern_size)

    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    found, corners = cv2.findChessboardCorners(bgr, pattern_size)
    if not found:
        if debug:
            print("Checkerboard not detected!")
        return None, None

    if debug:
        bgr_c = bgr.copy()
        cv2.drawChessboardCorners(bgr_c, pattern_size, corners, found)
        cv2.imshow("Found corners", bgr_c)

    dst = generate_checkerboard_points(pattern_size[0], pattern_size[1], square_size)
    if orig_shift is not None:
        dst = dst + orig_shift

    orig_points = get_interpolated_xyz(xyz, corners)

    # Keep only the corners with a positive interpolated depth, together with their
    # ideal counterparts, so the two point sets stay in correspondence.
    valid = np.where(orig_points[:, 2] > 0)
    orig_points = orig_points[valid]
    dst = dst[valid]

    # A rigid transform cannot be estimated from fewer than three correspondences.
    if len(orig_points) < 3:
        if debug:
            print("Not enough checkerboard corners with valid depth!")
        return None, None

    if debug:
        bgr_c = bgr.copy()
        for point in orig_points:
            cv2.circle(bgr_c, get_pixel_coords(point), 3, color=(0, 0, 255), thickness=-1)
        cv2.imshow("Orig reprojection", bgr_c)

        # perform PCA to check if the points lie on a plane
        pca = PCA()
        pca.fit(orig_points)
        # the last value should be much smaller than the other two
        print(pca.explained_variance_ratio_)

    R, t = get_Rt(orig_points, dst)

    if debug:
        # Map the ideal grid back into the camera frame and reproject it for a visual check.
        new_points = (np.linalg.inv(R) @ (dst - t).T).T
        bgr_c = bgr.copy()
        for point in new_points:
            cv2.circle(bgr_c, get_pixel_coords(point), 3, color=(0, 0, 255), thickness=-1)
        cv2.imshow("Verification", bgr_c)
        cv2.waitKey(0)

    return R, t


def get_pixel_coords(point):
    """ Project a 3D point into the image with the pinhole model

    Uses the module-level ``intrinsics`` (focal lengths 'fx', 'fy' and principal
    point 'ppx', 'ppy'). The result is truncated to integers with int().

    :param point: 3D point (indexable as point[0], point[1], point[2]); point[2] must be non-zero
    :return: u, v - pixel coordinates
    """
    depth = point[2]
    x_norm = point[0] / depth
    y_norm = point[1] / depth
    col = intrinsics['fx'] * x_norm + intrinsics['ppx']
    row = intrinsics['fy'] * y_norm + intrinsics['ppy']
    return int(col), int(row)


def create_xyz(depth):
    """ Back-project a depth map into an image-ordered array of xyz coordinates

    The depth values are divided by 1000 (presumably millimetres to metres) and
    every pixel is back-projected with the module-level ``intrinsics``.

    :param depth: depth map of shape height * width
    :return: xyz - float32 np array of shape height * width * 3
    """
    z = depth / 1000
    n_rows = intrinsics['height']
    n_cols = intrinsics['width']

    cloud = np.zeros([n_rows, n_cols, 3], dtype=np.float32)
    cloud[:, :, 2] = z

    # Pixel index vectors that broadcast against the depth map.
    col_idx = np.arange(0, n_cols)[np.newaxis, :]
    row_idx = np.arange(0, n_rows)[:, np.newaxis]

    cloud[:, :, 0] = (col_idx - intrinsics['ppx']) * z / intrinsics['fx']
    cloud[:, :, 1] = (row_idx - intrinsics['ppy']) * z / intrinsics['fy']
    return cloud.astype(np.float32)


def check_pixel_coords(xyz):
    """ Checker method to verify correct fx and ppx values

    Walks the pointcloud row by row. Every visited pixel with a positive depth contributes
    one sample (x / z, 1, column index). For rows with index greater than 4, the two samples
    preceding the current one are used to solve

        column = a * (x / z) + b

    and the solution [a, b] is printed; for a pinhole camera a and b should match the
    horizontal focal length and principal point.

    :param xyz: pointcloud ordered as an image, shape height x width x 3
    :return: None
    :raises numpy.linalg.LinAlgError: if the two samples being solved have the same x / z
    """
    height, width = xyz.shape[:2]

    # Only the last three samples are ever needed, so keep a short sliding window
    # instead of growing an array with every pixel.
    recent = []
    for i in range(height):
        for j in range(width):
            p = xyz[i, j, :]
            # Pixels without a positive depth cannot be normalised (division by zero).
            if not p[2] > 0:
                continue
            recent.append([p[0] / p[2], 1, j])
            del recent[:-3]
            if i > 4 and len(recent) == 3:
                pair = np.array(recent[:2], dtype=float)
                print(np.linalg.solve(pair[:, :2], pair[:, 2]))


def draw_object(obj, c_img):
    """Draw a detected object into the image, dispatching on its type

    :param obj: dict with a 'type' key; 'cube' and 'cuboid' objects also carry 'corners' and
        'conf', 'grasper' objects carry 'points'
    :param c_img: image to draw into
    :return: result of the matching draw function; c_img itself for any other type
    """
    kind = obj['type']
    if kind in ('cube', 'cuboid'):
        return draw_corners(obj['corners'], c_img, obj['conf'])
    if kind == 'grasper':
        return draw_grasper(obj['points'], c_img)
    return c_img

def draw_corners(corners, c_img, conf=1.0):
    """Adds cube to the image

    Every corner is marked with a filled red dot, the 12 edges are drawn as lines whose
    colour goes from red (conf = 0) to green (conf = 1), and finally corner 0 is
    re-marked in yellow.

    :param corners: list of 8 3d vectors representing corners; corners 0-3 and 4-7 form the
        two opposite faces and corner k is connected to corner k + 4
    :param c_img: m x n x 3 BGR image
    :param conf: confidence of the detection
    :return: m x n x 3 BGR image with the cube drawn; c_img unchanged if corners is None or empty
    """
    if corners is None or len(corners) == 0:
        return c_img

    # Project each corner once while marking it, and reuse the pixel positions for the edges.
    pixels = []
    for corner in corners:
        pixel = get_pixel_coords(corner)
        pixels.append(pixel)
        c_img = cv2.circle(c_img, pixel, 3, (0, 0, 255), thickness=-1)

    line_color = (0, conf * 255, (1 - conf) * 255)

    edges = (
        (0, 1), (1, 2), (2, 3), (3, 0),  # first face
        (4, 5), (5, 6), (6, 7), (7, 4),  # second face
        (0, 4), (1, 5), (2, 6), (3, 7),  # edges connecting the two faces
    )
    for start, end in edges:
        c_img = cv2.line(c_img, pixels[start], pixels[end], line_color, thickness=2)

    # Highlight the first corner so the orientation of the cube is visible.
    c_img = cv2.circle(c_img, pixels[0], 3, (0, 255, 255), thickness=-1)

    return c_img


def draw_grasper(grasper_points, c_img):
    r""" Draws two parallel lines for grasper the points are in this order
        ____       _____
       \    |     |    /
        \   |1   3|   /
         \  |     |  /
          \_|0   2|_/

    Top points are 10 mm above the bottom ones

    The bottom points (0, 2) are marked in yellow, the top points (1, 3) in red and the
    two lines 0-1 and 2-3 are drawn in green.

    :param grasper_points: 4 x 3 array of points as in the picture
    :param c_img: bgr image of the scene
    :return: image with lines added; c_img unchanged if grasper_points is None
    """
    if grasper_points is None:
        return c_img

    yellow = (0, 255, 255)
    red = (0, 0, 255)

    # Markers: bottom points first, then the top ones.
    for idx, color in ((0, yellow), (2, yellow), (1, red), (3, red)):
        c_img = cv2.circle(c_img, get_pixel_coords(grasper_points[idx]), 3, color, thickness=-1)

    # One line per finger, from its bottom point to its top point.
    for bottom, top in ((0, 1), (2, 3)):
        c_img = cv2.line(c_img, get_pixel_coords(grasper_points[bottom]),
                         get_pixel_coords(grasper_points[top]), (0, 255, 0))

    return c_img


if __name__ == "__main__":
    # Replay the recorded frames 5..49 and show the checkerboard debug views for each one.
    for frame_idx in range(5, 50):
        frame = np.load("collect_checkerboard/{}.npy".format(frame_idx))
        # The first three channels hold the colour image, the fourth one the depth map.
        xyz = create_xyz(frame[:, :, 3])
        rgb = frame[:, :, :3].astype('uint8')
        get_checkerboard_extrinsics(rgb, xyz, debug=True)
